import math
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

try:
    from torch._six import inf
except ImportError:  # torch._six was removed in PyTorch 2.0
    from math import inf

def conv3x3(in_planes, out_planes, stride=1, device="cuda"):
    """3x3 convolution with padding 1 and no bias.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        device: device the layer is moved to after construction. The default
            "cuda" keeps the historical behaviour (same as calling ``.cuda()``,
            i.e. the current CUDA device); pass None to leave the layer where
            ``nn.Conv2d`` created it (e.g. for CPU use).

    Returns:
        The ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
    return conv if device is None else conv.to(device)


def conv1x1(in_planes, out_planes, stride=1, device="cuda"):
    """1x1 convolution without bias.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        device: device the layer is moved to after construction. The default
            "cuda" keeps the historical behaviour (same as calling ``.cuda()``,
            i.e. the current CUDA device); pass None to leave the layer where
            ``nn.Conv2d`` created it (e.g. for CPU use).

    Returns:
        The ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return conv if device is None else conv.to(device)


def conv_bn(inp, oup, stride):
    """Return ``Sequential(3x3 conv, BatchNorm2d, in-place ReLU)``.

    The conv has padding 1 and no bias. The three children are registered
    positionally ("0", "1", "2"), which is what the state_dict keys rely on.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)


def conv_1x1_bn(num_input_channels, num_mid_channel):
    """Return ``Sequential(1x1 conv, BatchNorm2d, in-place ReLU)``.

    The conv comes from ``conv1x1`` (stride 1, no bias).
    NOTE(review): ``conv1x1`` moves its conv to CUDA, while the BatchNorm/ReLU
    stay on the default device; presumably the caller moves the whole model
    afterwards — confirm.
    """
    layers = [
        conv1x1(num_input_channels, num_mid_channel),
        nn.BatchNorm2d(num_mid_channel),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class Shake(nn.Module):
    """Fuse a list of transformer token maps into a regression output.

    Every token tensor is folded back into a square feature map. The maps are
    then accumulated in order with ``x = fuse1_bn(x) + fuse2_1x1_bn(next)``,
    refined by one more ``fuse1_bn``, globally average-pooled, and projected
    with a weight matrix supplied by the caller.
    """

    def __init__(self, feat_t):
        """Build the fusion layers.

        Args:
            feat_t: sequence of example token tensors, assumed to be shaped
                (B, HW, C); only ``shape[2]`` of entries 1, 2 and 11 is read.
        """
        super(Shake, self).__init__()
        # FR teacher
        # NOTE(review): forward() feeds fuse1_bn its own output, so this only
        # works when feat_t[1] and feat_t[2] share one channel count (and the
        # input maps have it too) — presumably every layer has the same width.
        self.fuse1_bn = conv_bn(feat_t[1].shape[2], feat_t[2].shape[2], 1)
        self.fuse2_1x1_bn = conv_1x1_bn(feat_t[2].shape[2], feat_t[2].shape[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Not used by forward() (the caller passes the projection weight), but
        # kept registered so "fc.weight" remains in the state_dict.
        self.fc = nn.Linear(feat_t[11].shape[2], 1, bias=False)  # MLP shadow

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, f, weight):
        """Fuse the token maps in ``f`` and project them with ``weight``.

        Args:
            f: sequence of token tensors shaped (B, HW, C) with HW a perfect
                square (e.g. 12 tensors of torch.Size([12, 196, 384])). The
                sequence itself is not modified.
            weight: projection matrix for ``nn.functional.linear``, presumably
                (out_features, C') where C' is fuse1_bn's output width.

        Returns:
            Tensor of shape (B, out_features).

        Raises:
            ValueError: if HW is not a perfect square.
        """
        B, new_HW, C = f[0].shape
        side = int(np.sqrt(new_HW))
        if side * side != new_HW:
            raise ValueError(f"expected a square token grid, got {new_HW} tokens")
        # Fold (B, HW, C) -> (B, C, side, side) into a NEW list so the caller's
        # list is not overwritten with reshaped tensors.
        maps = [t.transpose(1, 2).reshape(B, C, side, side) for t in f]
        x = maps[0]
        for nxt in maps[1:]:
            x = self.fuse1_bn(x) + self.fuse2_1x1_bn(nxt)
        x = self.fuse1_bn(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return nn.functional.linear(x, weight)

class ShakeCNN(nn.Module):
    """Fuse four CNN feature maps into a regression output.

    Stage by stage, the running map is downsampled by a stride-2 ``conv_bn``
    and added to a ``conv_1x1_bn`` projection of the next input map. For the
    additions to line up, each map presumably has half the spatial size of the
    previous one. The fused map is refined by a stride-1 ``conv_bn``, globally
    average-pooled, and passed through three caller-supplied linear weights.
    """

    def __init__(self, feat_cnn):
        super(ShakeCNN, self).__init__()

        def width(i):
            # Channel count (dim 1) of the i-th example map. Example endpoints
            # noted for this teacher: reduction_2 (1, 24, 56, 56),
            # reduction_3 (1, 40, 28, 28), reduction_4 (1, 112, 14, 14),
            # reduction_6 (1, 1280, 7, 7).
            return feat_cnn[i].shape[1]

        self.fuse1 = conv_bn(width(0), width(1), 2)
        self.fuse2 = conv_1x1_bn(width(1), width(1))
        self.fuse3 = conv_bn(width(1), width(2), 2)
        self.fuse4 = conv_1x1_bn(width(2), width(2))
        self.fuse5 = conv_bn(width(2), width(3), 2)
        self.fuse6 = conv_1x1_bn(width(3), width(3))
        self.fuse7 = conv_bn(width(3), width(3), 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, relu) for convs; weight 1 / bias 0 for BatchNorm."""
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)
                continue
            if not isinstance(mod, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            if mod.bias is not None:
                nn.init.constant_(mod.bias, 0)

    def forward(self, f, weight1, weight2, weight3):
        """Fuse ``f[0]``..``f[3]`` and apply ``weight1``, ``weight2``, ``weight3`` in turn.

        Args:
            f: sequence of at least four feature maps, assumed NCHW
                (e.g. the last one is torch.Size([B, 1280, 7, 7])).
            weight1, weight2, weight3: matrices for successive
                ``nn.functional.linear`` calls.

        Returns:
            Output of the last linear projection, shape (B, out_features).
        """
        fused = self.fuse1(f[0]) + self.fuse2(f[1])
        fused = self.fuse3(fused) + self.fuse4(f[2])
        fused = self.fuse5(fused) + self.fuse6(f[3])
        fused = self.fuse7(fused)
        out = torch.flatten(self.avgpool(fused), 1)
        for w in (weight1, weight2, weight3):
            out = nn.functional.linear(out, w)
        return out

class ShakeINN(nn.Module):
    """Fuse three feature maps into a regression output.

    The running map is downsampled by a stride-2 ``conv_bn`` and added to a
    ``conv_1x1_bn`` projection of the next input map, twice. For the additions
    to line up, each map presumably has half the spatial size of the previous
    one. The result is refined by a stride-1 ``conv_bn``, globally
    average-pooled, and passed through three caller-supplied linear weights.
    """

    def __init__(self, feat_inn):
        super(ShakeINN, self).__init__()

        def width(i):
            # Channel count (dim 1, assumed NCHW) of the i-th example map.
            return feat_inn[i].shape[1]

        # FR teacher
        self.fuse1 = conv_bn(width(0), width(1), 2)
        self.fuse2 = conv_1x1_bn(width(1), width(1))
        self.fuse3 = conv_bn(width(1), width(2), 2)
        self.fuse4 = conv_1x1_bn(width(2), width(2))
        self.fuse5 = conv_bn(width(2), width(2), 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, relu) for convs; weight 1 / bias 0 for BatchNorm."""
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)
                continue
            if not isinstance(mod, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            if mod.bias is not None:
                nn.init.constant_(mod.bias, 0)

    def forward(self, f, weight1, weight2, weight3):
        """Fuse ``f[0]``..``f[2]`` and apply ``weight1``, ``weight2``, ``weight3`` in turn.

        Args:
            f: sequence of at least three feature maps, assumed NCHW.
            weight1, weight2, weight3: matrices for successive
                ``nn.functional.linear`` calls.

        Returns:
            Output of the last linear projection, shape (B, out_features).
        """
        fused = self.fuse1(f[0]) + self.fuse2(f[1])
        fused = self.fuse3(fused) + self.fuse4(f[2])
        fused = self.fuse5(fused)
        out = torch.flatten(self.avgpool(fused), 1)
        for w in (weight1, weight2, weight3):
            out = nn.functional.linear(out, w)
        return out

def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Resume from ``config.MODEL.RESUME`` (a local path or an https URL).

    The ``"model"`` entry is always loaded non-strictly. Optimizer, LR scheduler,
    start epoch (``epoch + 1``), loss scaler and best PLCC are restored only
    when not in ``config.EVAL_MODE`` and the checkpoint carries
    ``"optimizer"``, ``"lr_scheduler"`` and ``"epoch"``.

    Args:
        config: config exposing MODEL.RESUME, EVAL_MODE, TRAIN.START_EPOCH and
            defrost()/freeze().
        model: module receiving checkpoint["model"].
        optimizer, lr_scheduler, loss_scaler: objects with load_state_dict();
            only touched when training state is restored.
        logger: logger for progress messages.

    Returns:
        (max_plcc, epoch): the stored best PLCC (0.0 if absent or not
        restored) and the checkpoint's epoch, or None when training state was
        not restored.
    """
    logger.info(
        f"==============> Resuming form {config.MODEL.RESUME}...................."
    )
    if config.MODEL.RESUME.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location="cpu", check_hash=True
        )
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location="cpu")
    msg = model.load_state_dict(checkpoint["model"], strict=False)
    logger.info(msg)
    max_plcc = 0.0
    # Default for when training state is not restored below (eval mode, or a
    # weights-only checkpoint); without it the return raised UnboundLocalError.
    epoched = None
    if (
        not config.EVAL_MODE
        and "optimizer" in checkpoint
        and "lr_scheduler" in checkpoint
        and "epoch" in checkpoint
    ):
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint["epoch"] + 1
        config.freeze()
        if "scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["scaler"])
        logger.info(
            f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})"
        )
        if "max_plcc" in checkpoint:
            max_plcc = checkpoint["max_plcc"]
        epoched = checkpoint["epoch"]
    del checkpoint
    torch.cuda.empty_cache()
    return max_plcc, epoched


def load_pretrained(config, model, logger):
    """Load ``config.MODEL.PRETRAINED`` into ``model`` for fine-tuning.

    The checkpoint is read on CPU and its ``"model"`` entry is adapted before a
    non-strict ``load_state_dict``:

    * keys containing ``relative_position_index``, ``relative_coords_table``
      or ``attn_mask`` are dropped (the model re-initialises them);
    * ``relative_position_bias_table`` / ``absolute_pos_embed`` entries whose
      token count differs from the model's are bicubically resized, assuming
      a square token grid;
    * entries of those two kinds whose head/channel count differs from the
      model's are skipped (removed) with a warning;
    * if ``head.bias`` is present and its class count differs, the head is
      remapped 21841 -> 1000 via ``data/map22kto1k.txt`` or re-initialised
      to zero and not loaded.

    Args:
        config: config exposing ``MODEL.PRETRAINED`` (checkpoint path).
        model: module to load into; must expose ``head`` when the checkpoint
            contains ``head.bias``.
        logger: logger for progress and warnings.
    """
    logger.info(
        f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......"
    )
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location="cpu")
    state_dict = checkpoint["model"]
    # Build the model's state dict once instead of once per inspected key.
    model_state = model.state_dict()

    # delete buffers the model always re-initialises itself
    for marker in ("relative_position_index", "relative_coords_table", "attn_mask"):
        for k in [key for key in state_dict if marker in key]:
            del state_dict[k]

    # bicubic interpolate relative_position_bias_table if not match
    for k in [key for key in state_dict if "relative_position_bias_table" in key]:
        table_pretrained = state_dict[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = model_state[k].size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
            # Actually skip it: a mismatched entry left in state_dict makes
            # load_state_dict raise even with strict=False.
            del state_dict[k]
        elif L1 != L2:
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            table_resized = torch.nn.functional.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2),
                mode="bicubic",
            )
            state_dict[k] = table_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    for k in [key for key in state_dict if "absolute_pos_embed" in key]:
        embed_pretrained = state_dict[k]
        _, L1, C1 = embed_pretrained.size()
        _, L2, C2 = model_state[k].size()
        if C1 != C2:  # compare against the model's channel count
            logger.warning(f"Error in loading {k}, passing......")
            # Skip it for the same reason as the bias-table case above.
            del state_dict[k]
        elif L1 != L2:
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            # (1, L1, C) -> (1, C, S1, S1) -> resize -> (1, L2, C)
            grid = embed_pretrained.reshape(-1, S1, S1, C1).permute(0, 3, 1, 2)
            grid = torch.nn.functional.interpolate(grid, size=(S2, S2), mode="bicubic")
            state_dict[k] = grid.permute(0, 2, 3, 1).flatten(1, 2)

    # check classifier, if not match, then re-init classifier to zero
    if "head.bias" in state_dict:
        Nc1 = state_dict["head.bias"].shape[0]
        Nc2 = model.head.bias.shape[0]
        if Nc1 != Nc2:
            if Nc1 == 21841 and Nc2 == 1000:
                logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
                map22kto1k_path = "data/map22kto1k.txt"
                with open(map22kto1k_path) as f:
                    map22kto1k = [int(id22k.strip()) for id22k in f]
                state_dict["head.weight"] = state_dict["head.weight"][map22kto1k, :]
                state_dict["head.bias"] = state_dict["head.bias"][map22kto1k]
            else:
                torch.nn.init.constant_(model.head.bias, 0.0)
                torch.nn.init.constant_(model.head.weight, 0.0)
                state_dict.pop("head.weight", None)
                del state_dict["head.bias"]
                logger.warning(
                    "Error in loading classifier head, re-init classifier head to 0"
                )

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)

    logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    del checkpoint
    torch.cuda.empty_cache()


def save_checkpoint(
    config, epoch, model, max_plcc, optimizer, lr_scheduler, loss_scaler, logger
):
    """Write the model weights to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    Only ``{"model": weights}`` is written, where ``weights`` is the model's
    state_dict without any entry whose key contains "generation". The
    max_plcc, optimizer, lr_scheduler and loss_scaler arguments are accepted
    but not saved. The file uses the legacy (non-zipfile) serialization.
    """
    weights = {
        name: tensor
        for name, tensor in model.state_dict().items()
        if "generation" not in name
    }
    target = os.path.join(config.OUTPUT, f"ckpt_epoch_{epoch}.pth")
    logger.info(f"{target} saving......")
    torch.save({"model": weights}, target, _use_new_zipfile_serialization=False)
    logger.info(f"{target} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    """Return the total ``norm_type``-norm of the gradients as a Python float.

    Args:
        parameters: a tensor or an iterable of tensors; entries without a
            ``.grad`` are ignored.
        norm_type: order p of the norm, anything ``float()`` accepts.
            ``float("inf")`` gives the largest absolute gradient entry.

    Returns:
        The norm of all gradients viewed as one vector (0.0 if none have a grad).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == math.inf:
        # The sum-of-powers formula below collapses to x ** 0 == 1.0 for p = inf.
        return max((g.abs().max().item() for g in grads), default=0.0)
    total_norm = 0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)


def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint path in ``output_dir``, or None.

    Any directory entry whose name ends with "pth" counts as a checkpoint;
    recency is decided by ``os.path.getmtime``. The candidate list and the
    chosen file are printed.
    """
    candidates = [name for name in os.listdir(output_dir) if name.endswith("pth")]
    print(f"All checkpoints founded in {output_dir}: {candidates}")
    if not candidates:
        return None
    newest = max(
        (os.path.join(output_dir, name) for name in candidates), key=os.path.getmtime
    )
    print(f"The latest checkpoint founded: {newest}")
    return newest


def reduce_tensor(tensor):
    """Average ``tensor`` over all ranks of the default process group.

    The input is cloned first and left untouched. The clone is summed in
    place with ``dist.all_reduce`` and then divided by the world size.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    summed /= world_size
    return summed


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total gradient norm as a tensor on the first gradient's device.

    Args:
        parameters: a tensor or an iterable of tensors; entries without a
            ``.grad`` are skipped.
        norm_type: order of the norm; ``inf`` gives the largest absolute entry.

    Returns:
        0-dim tensor with the norm of all gradients viewed as one vector, or
        ``torch.tensor(0.0)`` when no gradients are present.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grads = [p.grad for p in params if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)
    device = grads[0].device
    if norm_type == inf:
        return max(g.detach().abs().max().to(device) for g in grads)
    per_param = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_param), norm_type)


class NativeScalerWithGradNormCount:
    """Wrapper around ``torch.cuda.amp.GradScaler`` that also reports grad norms.

    Each call runs a scaled backward pass and, when ``update_grad`` is true,
    unscales the gradients, optionally clips them, steps the optimizer and
    updates the loss scale.
    """

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        retain_graph=None,
        update_grad=True,
    ):
        """Backpropagate ``loss`` and optionally take an optimizer step.

        Args:
            loss: scalar loss tensor.
            optimizer: optimizer whose parameters receive the gradients.
            clip_grad: max norm for ``clip_grad_norm_``; None disables clipping.
            parameters: parameters whose gradient norm is measured/clipped;
                required when ``clip_grad`` is given.
            create_graph, retain_graph: forwarded to ``backward``.
            update_grad: when False only the backward pass runs (e.g. for
                gradient accumulation).

        Returns:
            The gradient norm after unscaling (the pre-clipping total norm when
            clipping), or None when ``update_grad`` is False.
        """
        self._scaler.scale(loss).backward(retain_graph=retain_graph, create_graph=create_graph)
        if not update_grad:
            return None
        # Gradients are unscaled in place first so norms/clipping see true values.
        if clip_grad is None:
            self._scaler.unscale_(optimizer)
            norm = ampscaler_get_grad_norm(parameters)
        else:
            assert parameters is not None
            self._scaler.unscale_(optimizer)
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped GradScaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
